#!/usr/bin/env python3

# Copyright (c) 2023 Arm Limited and Contributors. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import argparse
import atexit
import hashlib
import logging
import os
import shutil
import subprocess
import sys
import tempfile
import traceback
from collections import defaultdict
from dataclasses import dataclass
from enum import Enum, auto
from functools import cached_property, lru_cache, partial
from itertools import combinations
from os import PathLike
from typing import Any, Callable, Iterable, Iterator, Optional, Tuple, Union

import arpy
import lief
import lief.ELF as ELF


class Const:
    """Program-wide constants."""

    # File-name extensions (including the leading dot) used to classify
    # command-line arguments. Stored as frozensets so these shared, class-level
    # constants cannot be mutated by accident.
    OBJECT_FILE_EXTS = frozenset({".o"})
    ARCHIVE_EXTS = frozenset({".a"})
    # Log record format handed to logging.basicConfig / coloredlogs.install.
    LOG_FORMAT = "%(message)s"
    # hashlib algorithm name used to fingerprint symbol contents.
    HASHFN = "md5"
    # argparse epilog; "{0}" is replaced with the program's basename.
    EPILOG = """
Exit status:
    * 255 in the case of user or internal error, or
    * The number of ODR violations up to a maximum of 254, or
    * The exit status of the executed program, if any

Example:
    {0} foo.o bar.o
    Scan foo.o and bar.o, then exit.

Example:
    {0} -vv gcc -o foobar main.o foo.o bar.o
    Enable verbose logging, scan main.o, foo.o and bar.o, then execute the gcc command.
"""


class Error:
    """Forward declaration so that annotations evaluated while the Error enum
    below is being defined can name Error; the enum then rebinds the name."""


class Error(Enum):
    """ODR diagnostics.

    Each member's value is unpacked by __init__ as (numeric code, message
    template, logging severity, suppressable flag). The template placeholders
    are filled in by format().
    """

    ZeroSize = (
        100,
        "Symbol {symbol1} has size of 0 in {binary1} or {binary2}",
        logging.WARNING,
        True,
    )
    NoContent = (
        101,
        "Symbol {symbol1} has no content in {binary1} or {binary2}",
        logging.WARNING,
        True,
    )
    TypeMismatch = (
        102,
        "Symbol {symbol1} has type {value1} in {binary1} but {value2} in {binary2}",
        logging.ERROR,
        True,
    )
    SizeMismatch = (
        103,
        "Symbol {symbol1} has size {value1} in {binary1} but {value2} in {binary2}",
        logging.ERROR,
        True,
    )
    HashMismatch = (
        104,
        "Symbol {symbol1} has hash {value1} in {binary1} but {value2} in {binary2}",
        logging.ERROR,
        False,
    )

    @classmethod
    def _first_where(cls, attribute: str, wanted, suppressable: bool):
        """Return the first member whose `attribute` equals `wanted`, or None.
        When `suppressable` is True only suppressable members are searched."""
        candidates = cls.get_suppressable() if suppressable else iter(cls)
        return next(
            (member for member in candidates if getattr(member, attribute) == wanted),
            None,
        )

    @classmethod
    def from_code(cls, code: int, suppressable: bool = False) -> "Error":
        """
        Look up an Error by its numeric code.
        :param suppressable: If True, only consider suppressable errors.
        :returns: The matching Error.
        :raises ValueError: If no Error matches.
        """
        found = cls._first_where("code", code, suppressable)
        if found is None:
            raise ValueError(f"No error has code {code}")
        return found

    @classmethod
    def from_name(cls, name: str, suppressable: bool = False) -> "Error":
        """
        Look up an Error by its member name, e.g. "ZeroSize".
        :param suppressable: If True, only consider suppressable errors.
        :returns: The matching Error.
        :raises ValueError: If no Error matches.
        """
        found = cls._first_where("name", name, suppressable)
        if found is None:
            raise ValueError(f"No error has name {name}")
        return found

    @classmethod
    def get_suppressable(cls) -> "Iterable[Error]":
        """Get a one-shot iterator over the Errors that may be suppressed."""
        return (member for member in cls if member.suppressable)

    def __init__(self, code: int, message: str, severity: int, suppressable: bool):
        # Enum calls this once per member with the member's tuple value.
        self._code, self._message = code, message
        self._severity, self._suppressable = severity, suppressable

    @property
    def code(self) -> int:
        """Numeric error code, shown as E<code>."""
        return self._code

    @property
    def message(self) -> str:
        """Message template with str.format placeholders."""
        return self._message

    @property
    def severity(self) -> int:
        """logging level the error is reported at."""
        return self._severity

    @property
    def suppressable(self) -> bool:
        """Whether the user may suppress this error."""
        return self._suppressable

    def __str__(self) -> str:
        return "E{}: {}".format(self.code, self.message)

    def format(self, **kwargs) -> str:
        """Render "E<code>: <message>" with the placeholders filled from kwargs."""
        template = str(self)
        return template.format(**kwargs)


class ErrorLogger:
    """Report Error diagnostics through the logging module, skipping any error
    whose code is on the suppression list."""

    # Codes of suppressed errors. This is a class attribute, so the list is
    # shared by every user of ErrorLogger.
    _suppressions = set()

    @classmethod
    def suppress(cls, e: Union[Error, int, str]) -> bool:
        """
        Add an error to the suppression list.
        :param e: An Error, its numeric code, or its member name.
        :returns: True if the suppression was added, False if the error is not
        suppressable.
        :raises ValueError: If a code or name matches no Error.
        """
        if isinstance(e, str):
            e = Error.from_name(e)
        elif isinstance(e, int):
            e = Error.from_code(e)

        assert isinstance(e, Error)

        if not e.suppressable:
            logging.warning(f"Not suppressable: {e.name} ({e.code})")
            return False
        cls._suppressions.add(e.code)
        return True

    @classmethod
    def log(cls, e: Error, **kwargs) -> bool:
        """
        Log an ODR error at its severity, unless it is suppressed.
        :param kwargs: Values for the placeholders in the error's message.
        :returns: True if the error was logged, False if it was suppressed.
        """
        suppressed = e.code in cls._suppressions
        if not suppressed:
            logging.log(e.severity, e.format(**kwargs))
        return not suppressed


class SymbolProperty(Enum):
    """A symbol attribute that is compared between same-named symbols."""

    # Explicit values, identical to what auto() assigns in this order.
    size = 1
    type = 2
    hash = 3


class FileType(Enum):
    """Classification of a command-line argument."""

    # Explicit values, identical to what auto() assigns in this order.
    UNKNOWN = 1
    PROGRAM = 2
    OBJFILE = 3
    ARCHIVE = 4


class MaybeFile:
    """Wrap a command-line argument that might name a file, and classify it as
    an object file, an archive, a program or unknown."""

    def __init__(self, path: str, orig_path: Optional[str] = None):
        """
        :param path: Path to a file, or the bare name of a program in PATH.
        :param orig_path: For a file extracted from an archive, the archive path
        joined with the member name; otherwise None.
        """
        self._path = path
        self._orig_path = orig_path
        self._type = FileType.UNKNOWN

        if os.path.isfile(path):
            # Existing files are classified by extension first; any other file
            # the current user can execute is treated as a program.
            _, ext = os.path.splitext(path)
            if ext in Const.OBJECT_FILE_EXTS:
                self._type = FileType.OBJFILE
            elif ext in Const.ARCHIVE_EXTS:
                self._type = FileType.ARCHIVE
            elif self._is_executable_file(path):
                self._type = FileType.PROGRAM
        elif path in self._get_exec_files():
            # Not a file on disk, but the name of an executable found in PATH.
            self._type = FileType.PROGRAM

    @property
    def path(self) -> str:
        """Get the path"""
        return self._path

    @property
    def original_path(self) -> Optional[str]:
        """
        If the file was extracted from an archive, the original_path is the path to the
        archive joined with the name of the file inside the archive. If the file was
        not extracted, then original_path is None.
        """
        return self._orig_path

    @property
    def type(self) -> FileType:
        """Get the filetype"""
        return self._type

    def __str__(self) -> str:
        # e.g. "foo.o (OBJFILE)"
        type_ = str(self._type).split(".")[-1]
        return f"{self._path} ({type_})"

    def extract(
        self, accept: Optional[Callable[[str], bool]] = None
    ) -> Iterator[Tuple[str, Optional[str]]]:
        """Generator to extract files from an archive into a temporary directory.
        :param accept: If given, callable which receives an archive member name
        and indicates whether to extract it. If not given, extract every member.
        :yields: Tuples of (1) the file's extracted path and (2) the file's
        original path. If this file is not an archive, a single tuple of
        (path, original_path) for this file itself is yielded.
        :note: The temporary directory will be auto-deleted at program exit."""
        if self.type != FileType.ARCHIVE:
            # Nothing to extract; keep the documented tuple shape.
            yield self.path, self.original_path
            return

        def rmdir(p: PathLike):
            shutil.rmtree(p)
            logging.debug(f"Deleted {p}")

        if not accept:
            accept = lambda _: True  # noqa: 731

        with arpy.Archive(self.path) as ar:
            logging.info(f"Extracting object files from {self.path}")

            tmpdir = tempfile.mkdtemp()
            atexit.register(partial(rmdir, tmpdir))

            for index, file in enumerate(ar.namelist()):
                file_str = file.decode("utf-8")
                if not accept(file_str):
                    logging.debug(f"Filter rejected {self.path}/{file_str}")
                    continue

                # Member names come from the archive: keep only the final path
                # component so an absolute or "../" name cannot escape tmpdir,
                # and give every member its own sub-directory because an ar
                # archive may hold several members with the same name. Paths
                # yielded earlier therefore stay valid until program exit.
                base_name = os.path.basename(file_str)
                if not base_name:
                    logging.debug(f"Skipping unnamed member {self.path}/{file_str}")
                    continue
                member_dir = os.path.join(tmpdir, str(index))
                os.mkdir(member_dir)
                save_path = os.path.join(member_dir, base_name)

                orig_path = os.path.join(self.path, file_str)
                with open(save_path, "wb") as wfp, ar.open(file) as rfp:
                    wfp.write(rfp.read())
                    logging.debug(f"Extracted {self.path}/{file_str} to {save_path}")
                yield save_path, orig_path

    @staticmethod
    def _is_executable_file(path: PathLike) -> bool:
        """Indicate whether path names an existing regular file the current user
        can execute."""
        # os.access(X_OK) alone is also true for searchable directories.
        return os.path.isfile(path) and os.access(path, os.X_OK)

    @staticmethod
    def _is_listable_dir(path: PathLike) -> bool:
        """Indicate whether path names an existing directory the current user can
        list."""
        return os.path.isdir(path) and os.access(path, os.X_OK | os.R_OK)

    # NB: the lru_cache decorator caches the return value on the first call and
    # subsequent calls return the cached value. This is so we don't scan the PATH
    # directories repeatedly.
    # Technically the return value could change between calls, e.g. if the user
    # toggles the X permission on something, and this would not be reflected in the
    # cached value. But in this case, the worst thing that can happen is when we try
    # to execute a program passed on the command-line, the user gets a permission
    # error message, which is the best we could do by detecting it pre-emptively
    # anyway.
    @classmethod
    @lru_cache
    def _get_exec_files(cls) -> Iterable[str]:
        """Get the set of names of executable files in the PATH directories."""
        files = set()
        for path_dir in os.get_exec_path():
            if cls._is_listable_dir(path_dir):
                for file in os.listdir(path_dir):
                    path = os.path.join(path_dir, file)
                    if cls._is_executable_file(path):
                        files.add(file)
        return files


class FileSet(Iterable):
    """Ordered collection of MaybeFiles. The objects property lists the object
    files in the set, expanding any archive into its extracted object files."""

    def __init__(self, paths: Iterable[str]):
        """:param paths: Paths (or program names) to wrap as MaybeFiles."""
        # Materialised once so the set supports len() and indexing.
        self._files: Iterable[MaybeFile] = list(map(MaybeFile, paths))

    @property
    def files(self) -> Iterable[MaybeFile]:
        """Get the files in the set."""
        return self._files

    def __len__(self) -> int:
        return len(self.files)

    def __iter__(self) -> Iterator[MaybeFile]:
        return iter(self.files)

    def __getitem__(self, i: int) -> MaybeFile:
        return self.files[i]

    @property
    def objects(self) -> Iterator[MaybeFile]:
        """Generate the object files in the set, including object files
        extracted from archive files.

        Every access returns a new generator, so the objects can be iterated
        more than once (previously a cached, single-use generator was returned
        and later iterations were silently empty). Each iteration extracts
        archives afresh.
        """

        def is_objfile(p: str) -> bool:
            _, ext = os.path.splitext(p)
            return ext in Const.OBJECT_FILE_EXTS

        for file in self._files:
            if file.type == FileType.OBJFILE:
                yield file
            elif file.type == FileType.ARCHIVE:
                # Only extract archive members that look like object files.
                for save_path, orig_path in file.extract(is_objfile):
                    yield MaybeFile(save_path, orig_path)


class Symbol:
    """Forward declaration so annotations evaluated before the full Symbol class
    (e.g. in OdrCheckResult.__init__) can name Symbol; the full class defined
    later rebinds the name."""


class Binary:
    """Interface to an ELF object file parsed with LIEF."""

    def __init__(self, binary: Union[str, MaybeFile]):
        """
        :param binary: Path to an object file, or a MaybeFile of type OBJFILE
        (possibly extracted from an archive, in which case its original path is
        used as the binary's name).
        :raises ValueError: If LIEF cannot parse the file as an ELF binary.
        """
        self._orig_name = None
        if isinstance(binary, MaybeFile):
            assert binary.type == FileType.OBJFILE
            self._orig_name = binary.original_path
            binary = binary.path
        assert isinstance(binary, str)
        parsed = lief.parse(binary)
        # lief.parse returns None for input it cannot parse, and a non-ELF
        # binary type for other formats. Everything below assumes an ELF.Binary,
        # so fail here with a clear message rather than later with an obscure
        # AttributeError.
        if not isinstance(parsed, ELF.Binary):
            raise ValueError(f"Cannot parse {binary} as an ELF binary")
        self._binary = parsed

    @property
    def binary(self) -> ELF.Binary:
        """Get the wrapped binary."""
        return self._binary

    @property
    def name(self):
        """Get the binary's name: the archive-qualified original path for an
        extracted file, otherwise the name LIEF reports."""
        if self._orig_name is None:
            return self._binary.name
        return self._orig_name

    @property
    def sections(self):
        """Get the sections in the binary."""
        return self._binary.sections

    @cached_property
    def symbols(self):
        """Get the symbols in the binary, wrapped as Symbols."""
        return list(map(partial(Symbol, self), self.binary.symbols))

    @cached_property
    def global_symbols(self):
        """Get the symbols with global (or GNU unique) binding."""
        return list(filter(lambda s: s.is_global, self.symbols))


class OdrCheckResult:
    """Represents the result of comparing one property of two same-named
    symbols.

    :param symbol1: The first symbol compared.
    :param symbol2: The second symbol compared.
    :param property: The property that was checked, e.g. SymbolProperty.size.
    :param value1: The property's value for symbol1.
    :param value2: The property's value for symbol2.
    """

    def __init__(
        self,
        symbol1: Symbol,
        symbol2: Symbol,
        property: SymbolProperty,
        value1: Any,
        value2: Any,
    ):
        self.property = property
        self.symbol1 = symbol1
        self.symbol2 = symbol2
        self.value1 = value1
        self.value2 = value2

    @property
    def binary1(self):
        """Get the binary containing symbol1."""
        return self.symbol1.binary

    @property
    def binary2(self):
        """Get the binary containing symbol2."""
        return self.symbol2.binary

    @cached_property
    def error(self) -> Optional[Error]:
        """
        Classify the comparison.
        :returns: The Error describing the inconsistency, or None if the two
        values are consistent.
        :raises ValueError: If self.property has no classification rule.
        """
        if self.property == SymbolProperty.size:
            # A zero size is reported as such, even when both sizes are 0.
            if self.value1 == 0 or self.value2 == 0:
                return Error.ZeroSize
            elif self.value1 != self.value2:
                return Error.SizeMismatch
        elif self.property == SymbolProperty.type:
            if self.value1 != self.value2:
                return Error.TypeMismatch
        elif self.property == SymbolProperty.hash:
            # A None hash means the symbol's content could not be read.
            if self.value1 is None or self.value2 is None:
                return Error.NoContent
            elif self.value1 != self.value2:
                return Error.HashMismatch
        else:
            raise ValueError(f"Bug: No case for property={self.property}")

        return None


class Symbol:  # noqa: F811
    """Interface to a symbol of a Binary."""

    def __init__(self, binary: Binary, symbol: ELF.Symbol):
        """
        :param binary: The Binary the symbol belongs to.
        :param symbol: The wrapped LIEF ELF symbol.
        """
        self._binary = binary
        self._symbol = symbol

    @property
    def binary(self) -> Binary:
        """Get the binary."""
        return self._binary

    @property
    def symbol(self) -> ELF.Symbol:
        """Get the wrapped symbol"""
        return self._symbol

    @property
    def type(self):
        """Get the symbol's type"""
        return self.symbol.type

    @property
    def size(self):
        """Get the symbol's size"""
        return self.symbol.size

    @property
    def is_global(self) -> bool:
        """Indicate whether the symbol has global (or GNU unique) binding"""
        return self.symbol.binding in (
            ELF.SYMBOL_BINDINGS.GLOBAL,
            ELF.SYMBOL_BINDINGS.GNU_UNIQUE,
        )

    @property
    def name(self) -> str:
        """Get the symbol's name, may be mangled"""
        return self.symbol.name

    @cached_property
    def friendly_name(self) -> str:
        """Get the symbol's "friendly" name which includes the demangled version
        if the usual name is mangled"""
        if (
            self.symbol.demangled_name in (None, "")
            or self.symbol.name == self.symbol.demangled_name
        ):
            return self.symbol.name
        return f"{self.symbol.demangled_name} ({self.symbol.name})"

    @cached_property
    def section(self):
        """Get the symbol's section, or None (with a warning) if its section
        index is out of range for the binary."""
        # Documentation says there is a section property on Symbol, which may be None,
        # but at least some versions don't seem to have this property at all.
        # In either case we can work around the problem by using the section index.
        section = getattr(self.symbol, "section", None)
        if section:
            return section
        shndx = self.symbol.shndx
        if shndx < 0 or shndx >= len(self.binary.sections):
            logging.warning(
                f"Bad section index {shndx} for {self.name} in {self.binary.name}"
            )
            return None
        return self.binary.sections[shndx]

    @cached_property
    def content(self) -> Optional[bytes]:
        """Get the contents of the symbol as bytes, or None if the symbol has no
        usable section or a size of 0."""
        if self.section is None:
            return None

        if self.symbol.size == 0:
            return None

        # The symbol's value is used as its offset within its section.
        slice_ = slice(self.symbol.value, self.symbol.value + self.symbol.size)
        # LIEF versions differ in what Section.content returns (e.g. a
        # memoryview or a list of ints); bytes() accepts either, so the result
        # is always bytes, safe to cache, and accepted by hashlib.
        return bytes(self.section.content[slice_])

    @cached_property
    def content_hash(self) -> Optional[str]:
        """Get the hexadecimal representation of the hash of the symbol's content,
        or None if the symbol has no content."""
        if self.content is None:
            return None
        hashfn = hashlib.new(Const.HASHFN)
        hashfn.update(self.content)
        return hashfn.hexdigest()

    def compare_with(self, other: Symbol) -> Iterable[OdrCheckResult]:
        """
        Compare properties with another Symbol and return the results of each check
        :returns: The results for type, size and content hash, in that order,
        even if the properties checked are the same.
        """
        return [
            OdrCheckResult(self, other, SymbolProperty.type, self.type, other.type),
            OdrCheckResult(self, other, SymbolProperty.size, self.size, other.size),
            OdrCheckResult(
                self, other, SymbolProperty.hash, self.content_hash, other.content_hash
            ),
        ]


class OdrChecker:
    """Check for ODR violations within a set of files"""

    # ELF section index of undefined symbols (SHN_UNDEF). A global symbol in
    # this section is a reference to a definition elsewhere, not a definition,
    # so it cannot take part in an ODR violation.
    _SHN_UNDEF = 0

    def __init__(self):
        self._binaries = None
        self._symbols = None
        self._collisions = None
        self._violations = 0

    @property
    def binaries(self):
        """Binaries parsed by the last check (None before the first check)."""
        return self._binaries

    @property
    def collisions(self):
        """Names of global symbols defined in more than one binary."""
        return self._collisions

    @property
    def violations(self):
        """Number of unsuppressed violations found by the last check."""
        return self._violations

    def check(self, fileset: Union[Iterable[str], FileSet]) -> int:
        """Perform checks on a set of object files and return the number of
        violations."""
        self._reset()
        self._preprocess(fileset)
        self._do_check()
        return self._violations

    def _reset(self):
        """Reset internal state"""
        self._violations = 0
        self._binaries = []
        self._symbols = defaultdict(list)
        self._collisions = set()

    def _preprocess(self, fileset):
        """Parse the ELF objects within the fileset and find symbols which may have
        violations, i.e. global symbols defined in more than one binary."""
        if isinstance(fileset, FileSet):
            self._binaries = list(map(Binary, fileset.objects))
        else:
            assert hasattr(fileset, "__iter__")
            self._binaries = list(map(Binary, fileset))

        for binary in self._binaries:
            for symbol in binary.global_symbols:
                if symbol.symbol.shndx == self._SHN_UNDEF:
                    # e.g. main.o calling foo() defined in foo.o: not a
                    # second definition, so comparing it would be a false
                    # positive.
                    logging.debug(
                        f"Ignoring undefined global symbol {symbol.name} in {binary.name}"
                    )
                    continue
                logging.debug(f"Found global symbol {symbol.name} in {binary.name}")
                definitions = self._symbols[symbol.name]
                definitions.append(symbol)

                if len(definitions) > 1:
                    self._collisions.add(symbol.name)

    def _do_check(self) -> int:
        """Perform the ODR checks, logging each unsuppressed violation.
        :returns: The number of violations counted."""
        if not self._collisions:
            logging.info("No global symbol names collide, ODR violation is impossible")
            return self._violations

        for symbol_name in sorted(self._collisions):
            # Sort by binary name for deterministic pairing and output order.
            sorted_symbols = sorted(
                self._symbols[symbol_name], key=lambda s: s.binary.name
            )

            for symbol1, symbol2 in combinations(sorted_symbols, 2):
                for result in symbol1.compare_with(symbol2):
                    error = result.error
                    if error is None:
                        continue
                    error_logged = ErrorLogger.log(
                        error,
                        symbol1=result.symbol1.friendly_name,
                        symbol2=result.symbol2.friendly_name,
                        binary1=result.binary1.name,
                        binary2=result.binary2.name,
                        value1=result.value1,
                        value2=result.value2,
                    )
                    # Suppressed errors are not counted.
                    if error_logged:
                        self._violations += 1

        if self._violations > 0:
            s = "" if self._violations == 1 else "s"
            logging.error(f"{self._violations} ODR violation{s} detected")
        else:
            logging.info("No ODR violation detected")
        return self._violations


@dataclass
class ParseArgsResult:
    """Result returned by parse_args.

    Fields, in constructor order:
    - program: name or path of the program to execute, or None if there is none.
    - args: the remaining command-line words (everything after the program, or
      the whole command when there is no program).
    - objects: the object files and archives found on the command line.
    - level: the logging level.
    - suppressions: the Errors the user asked to suppress.
    """

    program: Optional[str]
    args: Iterable[str]
    objects: FileSet
    level: int
    suppressions: Iterable[Error]


def parse_args() -> ParseArgsResult:
    """
    Parse command-line arguments up to first non-flag argument.

    Exits with status 255 if no command is given or a suppression names an
    unknown error code/name.
    :returns: A ParseArgsResult.
    """
    argv0 = os.path.basename(sys.argv[0])
    ap = argparse.ArgumentParser(
        epilog=Const.EPILOG.format(argv0).strip(),
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    supp = ", ".join(map(lambda e: f"{e.name} ({e.code})", Error.get_suppressable()))

    ap.add_argument("-v", action="store_true", help="show informational messages")
    ap.add_argument("-vv", action="store_true", help="show debug messages")
    ap.add_argument(
        "-s",
        action="append",
        metavar="",
        dest="suppress",
        help="Suppress warning. Pass 0 or more times. May be comma separated. "
        + f"Choose from: {supp}",
    )
    ap.add_argument("command", nargs=argparse.REMAINDER)
    args = ap.parse_args()

    if not args.command:
        ap.print_usage(sys.stderr)
        sys.exit(255)

    if args.vv:
        level = logging.DEBUG
    elif args.v:
        level = logging.INFO
    else:
        level = logging.WARNING

    suppressions = set()
    # action="append" leaves args.suppress as None when -s is never given.
    for spr in args.suppress or ():
        for token in map(str.strip, spr.split(",")):
            if not token:
                # Tolerate "-s 100," and "-s 100, 101".
                continue
            try:
                suppressions.add(Error.from_code(int(token)))
            except ValueError:
                # Not a number, or no such code: try it as an error name.
                try:
                    suppressions.add(Error.from_name(token))
                except ValueError:
                    logging.error(f"Got bad error code/name: {token}")
                    logging.debug(traceback.format_exc())
                    sys.exit(255)

    args = args.command
    files = FileSet(args)
    program = None
    if files[0].type == FileType.PROGRAM:
        program = args[0]
        args = args[1:]
    objects = FileSet(
        f.path for f in files if f.type in (FileType.OBJFILE, FileType.ARCHIVE)
    )

    return ParseArgsResult(
        program=program,
        args=args,
        objects=objects,
        level=level,
        suppressions=suppressions,
    )


def init_logging(level: int) -> None:
    """Configure logging to stderr at `level`. When stderr is a terminal and
    the optional coloredlogs package is importable, use it; otherwise fall back
    to logging.basicConfig."""
    options = dict(level=level, stream=sys.stderr)
    if os.isatty(sys.stderr.fileno()):
        try:
            import coloredlogs
        except ImportError:
            pass
        else:
            coloredlogs.install(fmt=Const.LOG_FORMAT, **options)
            return
    logging.basicConfig(format=Const.LOG_FORMAT, **options)


def enable_suppressions(suppressions: Iterable[Error]):
    """Register every Error in `suppressions` with ErrorLogger. An empty or
    None collection is a no-op."""
    for suppression in suppressions or ():
        ErrorLogger.suppress(suppression)


def main() -> int:
    """
    Check the object files named on the command line for ODR violations and,
    if none are found, execute the given program (if any).
    :returns: The exit status: the number of violations capped at 254, the
    executed program's exit status, or 0.
    """
    parse_result = parse_args()
    init_logging(parse_result.level)
    enable_suppressions(parse_result.suppressions)

    logging.info(
        f"Found {len(parse_result.objects)} object files: "
        + ", ".join(map(lambda o: o.path, parse_result.objects))
    )

    checker = OdrChecker()
    status = checker.check(parse_result.objects)
    if status:
        # Exit statuses are 8-bit and 255 is reserved for errors.
        return min(254, status)

    if not parse_result.program:
        logging.debug("No program to execute, stopping")
        return 0

    command = [parse_result.program, *parse_result.args]
    logging.info(f"Executing {' '.join(command)}")
    try:
        # Run with an argument list and no shell. Arguments containing spaces
        # or shell metacharacters reach the program unchanged. returncode is
        # the program's own exit status, whereas os.system returns an encoded
        # wait status (exit code << 8 on POSIX) that sys.exit would truncate
        # to 0.
        completed = subprocess.run(command)
    except FileNotFoundError:
        logging.error(f"Command not found: {parse_result.program}")
        return 127
    except OSError as err:
        logging.error(f"Cannot execute {parse_result.program}: {err}")
        return 126
    if completed.returncode < 0:
        # Killed by signal N: report 128 + N, as POSIX shells do.
        return 128 - completed.returncode
    return completed.returncode


if __name__ == "__main__":
    # Use sys.exit: the builtin exit() is added by the site module and is not
    # guaranteed to exist (e.g. under "python -S"). Unexpected exceptions are
    # mapped to status 255, as documented in Const.EPILOG, instead of Python's
    # default status of 1.
    try:
        status = main()
    except Exception:
        logging.exception("Internal error")
        status = 255
    sys.exit(status)
